# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A basic MNIST example using JAX with the mini-libraries stax and optimizers.

The mini-library jax.example_libraries.stax is for neural network building, and
the mini-library jax.example_libraries.optimizers is for first-order stochastic
optimization.
"""


import time
import itertools

import numpy.random as npr
import jax
import jax.numpy as jnp
from jax import jit, grad, random
from jax.example_libraries import optimizers
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, Relu, LogSoftmax
from examples import datasets
from geesibling.adapters.jax import parallelize, device_config, DeviceType


def loss(params, batch):
    """Return the negated mean, over examples, of ``sum(preds * targets)``.

    Args:
        params: Network parameters accepted by the module-level ``predict``.
        batch: An ``(inputs, targets)`` pair. ``targets`` is multiplied
            elementwise with the predictions and summed over axis 1
            (presumably one-hot labels -- confirm against the dataset loader).

    Returns:
        A scalar: ``-mean(sum(predict(params, inputs) * targets, axis=1))``.
    """
    features, labels = batch
    outputs = predict(params, features)
    # Collapse the class axis first, then average over the batch.
    per_example = jnp.sum(outputs * labels, axis=1)
    return -jnp.mean(per_example)


def accuracy(params, batch):
    """Return the fraction of examples whose arg-max prediction matches.

    Args:
        params: Network parameters accepted by the module-level ``predict``.
        batch: An ``(inputs, targets)`` pair; the class of each example is
            taken as the arg-max of ``targets`` along axis 1.

    Returns:
        The mean of the elementwise equality between predicted and target
        class indices.
    """
    features, labels = batch
    guessed = jnp.argmax(predict(params, features), axis=1)
    expected = jnp.argmax(labels, axis=1)
    hits = guessed == expected
    return jnp.mean(hits)


# Network definition: 14 hidden blocks of Dense(1024) followed by Relu, then a
# 10-unit Dense layer and LogSoftmax. The hidden blocks are generated inline
# (a fresh Dense(1024) per block) instead of being written out one by one.
init_random_params, predict = stax.serial(
    *(layer for _ in range(14) for layer in (Dense(1024), Relu)),
    Dense(10),
    LogSoftmax,
)

if __name__ == "__main__":
    # Script entry point: trains the network defined above on MNIST with the
    # momentum optimizer and prints train/test accuracy after every epoch.
    rng = random.PRNGKey(0)

    # Hyper-parameters.
    step_size = 0.001
    num_epochs = 1000
    batch_size = 128
    momentum_mass = 0.9

    train_images, train_labels, test_images, test_labels = datasets.mnist()
    num_train = train_images.shape[0]
    # Ceiling division: a trailing partial batch still counts as one batch.
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)

    def data_stream():
        """Yield ``(images, labels)`` mini-batches forever.

        A fresh permutation of the training indices is drawn for every pass
        over ``num_batches`` batches; the last batch of a pass may be shorter
        than ``batch_size``.
        """
        # Named distinctly from the outer JAX PRNG key ``rng`` to avoid
        # shadowing it.
        shuffle_rng = npr.RandomState(0)
        while True:
            perm = shuffle_rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size : (i + 1) * batch_size]
                yield train_images[batch_idx], train_labels[batch_idx]

    batches = data_stream()

    opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

    # NOTE(review): `parallelize` presumably places the update step across the
    # two listed GPUs according to the "sgp" policy, and the memory figures
    # are presumably byte counts (32 GiB / 22 GiB) -- confirm against the
    # geesibling documentation. `jit` is imported by this file but is not
    # applied to `update`; `parallelize` wraps it instead.
    @parallelize(
        devices=device_config(
            {
                "gpu:1": {
                    "type": DeviceType.gpu,
                    "memory": 32 * 1024 * 1024 * 1024,
                    "free_memory": 32 * 1024 * 1024 * 1024,
                    "execute_time": 0,
                },
                "gpu:0": {
                    "type": DeviceType.gpu,
                    "memory": 22 * 1024 * 1024 * 1024,
                    "free_memory": 22 * 1024 * 1024 * 1024,
                    "execute_time": 0,
                },
            }
        ),
        policy="sgp",
    )
    def update(i, opt_state, batch):
        """Apply one optimizer step for ``batch`` and return the new state."""
        params = get_params(opt_state)
        return opt_update(i, grad(loss)(params, batch), opt_state)

    _, init_params = init_random_params(rng, (-1, 28 * 28))
    opt_state = opt_init(init_params)
    itercount = itertools.count()

    # Loop-invariant: evaluation always runs on the first CPU device, so look
    # it up once instead of twice per epoch.
    cpu_device = jax.devices("cpu")[0]

    print("\nStarting training...")
    for epoch in range(num_epochs):
        # perf_counter is monotonic and high-resolution, so the measured
        # interval cannot be skewed by system clock adjustments.
        # NOTE(review): JAX dispatches work asynchronously, so this interval
        # may exclude device work still in flight -- confirm whether `update`
        # blocks before relying on the reported time.
        start_time = time.perf_counter()
        for _ in range(num_batches):
            opt_state = update(next(itercount), opt_state, next(batches))
        epoch_time = time.perf_counter() - start_time
        # Evaluate on the CPU: copy the parameters there and make it the
        # default device for the accuracy computations.
        with jax.default_device(cpu_device):
            params = jax.device_put(get_params(opt_state), cpu_device)
            train_acc = accuracy(params, (train_images, train_labels))
            test_acc = accuracy(params, (test_images, test_labels))
            print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
            print(f"Training set accuracy {train_acc}")
            print(f"Test set accuracy {test_acc}")
